[magpietts] added multiple validation dataloaders and log metrics per val data #15348
XuesongYang wants to merge 8 commits into NVIDIA-NeMo:main
Conversation
Force-pushed from 861e8b3 to fae5fcb
…VIDIA-NeMo#15189)

* added multiple validation dataloaders and log metrics per val data.

* Apply suggestion from @XuesongYang
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @Copilot
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @Copilot
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

* Apply suggestion from @Copilot
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>

---------

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Co-authored-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
…ation to on_validation_epoch_end.
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Force-pushed from fae5fcb to c9cc855
Pull request overview
Adds support for validating MagpieTTS on multiple datasets (multiple validation dataloaders) while improving how media artifacts (audio + attention visualizations) are prepared and logged to W&B/TensorBoard, and updates the example Lhotse config to the new dataset configuration structure.
Changes:
- Refactors validation media logging by separating data preparation (numpy arrays) from logger-specific emission (W&B/TB objects).
- Adds multi-dataloader validation support, including per-dataloader metric aggregation and an averaged validation loss for checkpointing.
- Updates the MagpieTTS Lhotse example config to remove the dataset: nesting and introduce a validation_ds.datasets list format.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| nemo/collections/tts/models/magpietts.py | Implements multi-validation-dataloader handling, refactors media logging, and adjusts Lhotse dataloader config expectations. |
| examples/tts/conf/magpietts/magpietts_lhotse.yaml | Updates example configuration to match the new train/validation dataset config structure and multi-val datasets list format. |
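As a rough illustration of the prepare-then-emit split mentioned in the changes above, here is a hypothetical sketch (not the PR's actual code; the function names, dictionary keys, and sample rate are assumptions):

import numpy as np
import wandb

def prepare_validation_media(audio_pred, attn_matrix, sample_rate=22050):
    # Preparation step: keep everything as plain numpy arrays so the output
    # stays logger-agnostic and easy to unit test.
    return {
        'audio_pred': np.asarray(audio_pred, dtype=np.float32),
        # Attention map assumed already scaled to [0, 1] for image logging.
        'attention': np.asarray(attn_matrix, dtype=np.float32),
        'sample_rate': sample_rate,
    }

def log_media_to_wandb(wandb_run, media, global_step):
    # Emission step: only here are W&B-specific objects created.
    wandb_run.log(
        {
            'val/audio_pred': wandb.Audio(media['audio_pred'], sample_rate=media['sample_rate']),
            'val/attention': wandb.Image(media['attention']),
        },
        step=global_step,
    )

Keeping the prepared output as numpy also makes it straightforward to feed the same arrays to a TensorBoard writer in a separate emission function.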
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…exists in val ds config
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
    return log_dict

def on_validation_epoch_end(self):
This method is too long. I would suggest creating two separate methods like:
if is_multi_dataloader:
    self._process_multi_dataloader_validation()
else:
    self._process_single_dataloader_validation()
Instead of maintaining four separate lists (all_losses, all_codebook_losses, all_alignment_losses, all_aligner_encoder_losses, all_local_transformer_losses), switch to a dict. That keeps it clean, something like:
for metric_name, metric_value in dataloader_logs.items():
    self.log(f"Loss:{dataloader_prefix}/{metric_name}", metric_value, prog_bar=False, sync_dist=True)
    aggregated_metrics.setdefault(metric_name, []).append(metric_value)
- Using a dict instead of 4 separate lists: this is a good suggestion! Will make the changes.
- Splitting into two methods: reasonable for readability, though the method is currently ~115 lines, which is manageable. The split would reduce cognitive load. After the refactoring in item 1, the ~115 lines are reduced to ~50 lines.
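For concreteness, a rough sketch of what the dict-based aggregation could look like after the refactor; the validation_step_outputs attribute, its nesting, and the dataloader prefix format are assumptions rather than the final implementation:

import torch

def on_validation_epoch_end(self):
    # Normalize to a list of per-dataloader output lists so the single- and
    # multi-dataloader cases share one code path.
    outputs = self.validation_step_outputs
    if outputs and isinstance(outputs[0], dict):
        outputs = [outputs]

    aggregated = {}  # metric name -> list of per-dataloader mean values
    for dataloader_idx, dataloader_outputs in enumerate(outputs):
        dataloader_prefix = f"dataloader_{dataloader_idx}"
        per_metric_values = {}
        for step_output in dataloader_outputs:
            for metric_name, metric_value in step_output.items():
                if metric_value is not None:
                    per_metric_values.setdefault(metric_name, []).append(metric_value)
        for metric_name, values in per_metric_values.items():
            mean_value = torch.stack(values).mean()
            self.log(f"Loss:{dataloader_prefix}/{metric_name}", mean_value, prog_bar=False, sync_dist=True)
            aggregated.setdefault(metric_name, []).append(mean_value)

    # Average val_loss across dataloaders so a single value can drive checkpointing.
    if 'val_loss' in aggregated:
        self.log('val_loss', torch.stack(aggregated['val_loss']).mean(), sync_dist=True)
    self.validation_step_outputs.clear()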
# Compute required metrics
val_loss = collect_required_metric(outputs, 'val_loss')
Just a suggestion: I see the metric names being hard-coded everywhere and loosely used. Instead, we should define them in a constant Enum.
I may not follow which metrics are "hard-coded" in this context. Could you please elaborate?
On second thought about your suggestion, maybe we can end up with something like the following:
from enum import Enum

class ValidationMetric(Enum):
    """Validation metric keys used in validation_step outputs and logging."""
    LOSS = 'loss'
    CODEBOOK_LOSS = 'codebook_loss'
    ALIGNMENT_LOSS = 'alignment_loss'
    ALIGNER_ENCODER_LOSS = 'aligner_encoder_loss'
    LOCAL_TRANSFORMER_LOSS = 'local_transformer_loss'

# Usage:
val_loss = collect_required_metric(outputs, f'val_{ValidationMetric.LOSS.value}')
log_dict = {
    ValidationMetric.LOSS.value: val_loss,
    ValidationMetric.CODEBOOK_LOSS.value: val_codebook_loss,
}

If my understanding aligns with yours, I would leave it for future cleanup, because we are not adding new metrics frequently. It is a good suggestion, and we can track it as a follow-up improvement.
Sure, let's take it up for future cleanup.
val_codebook_loss = collect_required_metric(outputs, 'val_codebook_loss')

# Compute optional metrics
val_alignment_loss = collect_optional_metric(outputs, 'val_alignment_loss')
Can something like this be done?
for metric in VAL_METRIC_LIST:
    log_dict[metric] = collect_optional_metric(outputs, metric)
We need to check for Nones.
None values are already handled in validation_step, so there is no need to check for them again here. Besides, collect_optional_metric returns None if a metric does not exist in outputs.
> Can something like this be done?
>
> for metric in VAL_METRIC_LIST:
>     log_dict[metric] = collect_optional_metric(outputs, metric)
>
> We need to check for Nones.

We only have two required losses, so I believe there is no need to wrap them in a for loop. Similarly, there are only three optional losses. I would leave it as is unless we need to add many more losses in the future.
I meant doing something like the following: if we need to add any additional loss in the future, we do not need to touch this code, and it also makes things much cleaner.
VAL_METRIC_LIST = ['val_loss', 'val_codebook_loss', 'val_alignment_loss', ...]

# This is where a global list/Enum of metrics would help, so we can just iterate over them.
# But for now a local list would be fine.
for metric in VAL_METRIC_LIST:
    metric_value = collect_optional_metric(outputs, metric)
    if metric_value is not None:
        log_dict[metric] = metric_value
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
@subhankar-ghosh @blisc I addressed all your suggestions. Please have a look.
    'context_audios': context_audios,
}

def _log_media_to_wandb_and_tb(self, media_data: ValidationMediaData, global_step: int) -> None:
I am not a fan of passing dataclasses to functions, especially if you do not reuse them. It hides the underlying datatypes of the class, and it is not intuitive from the docstring what you are supposed to pass to this function. Is there an argument for why you want to do it this way?
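For context, a small hypothetical sketch of the two call shapes being discussed; apart from ValidationMediaData and context_audios, which appear in the diff above, the field names are assumptions:

from dataclasses import dataclass
import numpy as np

@dataclass
class ValidationMediaData:
    # Typed fields document what the logging function expects.
    audio_preds: list[np.ndarray]
    context_audios: list[np.ndarray]
    attn_images: list[np.ndarray]
    sample_rate: int

# Option A: pass the dataclass, as the PR currently does:
#     self._log_media_to_wandb_and_tb(media_data, global_step)
#
# Option B: pass explicit arguments, as the reviewer prefers, so the
# signature itself spells out the underlying types:
#     def _log_media_to_wandb_and_tb(self, audio_preds, context_audios,
#                                    attn_images, sample_rate, global_step): ...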
Summary
Details
The YAML config for the validation datasets uses a list of dataset entries, which generalizes naturally to datasets in multiple languages.
wandb log see here: https://wandb.ai/aiapps/debug_magpieTTS_EN_2509/runs/bqerks4y?nw=nwuserxuesong_yang
The model YAML config keys that were previously under train_ds.dataset are now directly under train_ds. The configuration structure for validation_ds changes to a list of datasets.
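Purely as an illustration of how a model might consume the new list-based validation config (the name key, the _setup_dataloader_from_config helper, and the attribute names below are hypothetical, not taken from the PR):

from omegaconf import DictConfig

def setup_multiple_validation_data(self, cfg: DictConfig):
    # cfg.validation_ds.datasets is a list of per-dataset configs, one entry
    # per validation set (for example, one per language).
    self._validation_dl = []
    self._validation_names = []
    for idx, ds_cfg in enumerate(cfg.validation_ds.datasets):
        # Hypothetical helper that builds one Lhotse dataloader from one entry.
        self._validation_dl.append(self._setup_dataloader_from_config(ds_cfg))
        # An optional per-dataset name can drive the per-dataloader logging prefix.
        self._validation_names.append(ds_cfg.get('name', f'val_{idx}'))
    # Returning a list of dataloaders gives Lightning a dataloader_idx in
    # validation_step and per-dataloader outputs to aggregate.
    return self._validation_dl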